{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAN介绍\n",
    "生成对抗网络（Generative Adversarial Networks，GAN）最早由 Ian Goodfellow 在 2014 年提出，是目前深度学习领域最具潜力的研究成果之一。它的核心思想是：同时训练两个相互协作、同时又相互竞争的深度神经网络（一个称为生成器 Generator，另一个称为判别器 Discriminator）来处理无监督学习的相关问题。  \n",
    "\n",
    "通常，我们会用下面这个例子来说明 GAN 的原理：将警察视为判别器，制造假币的犯罪分子视为生成器。一开始，犯罪分子会首先向警察展示一张假币。警察识别出该假币，并向犯罪分子反馈哪些地方是假的。接着，根据警察的反馈，犯罪分子改进工艺，制作一张更逼真的假币给警方检查。这时警方再反馈，犯罪分子再改进工艺。不断重复这一过程，直到警察识别不出真假，那么模型就训练成功了。  \n",
    "\n",
    "GAN的变体非常多，我们就以深度卷积生成对抗网络（Deep Convolutional GAN，DCGAN）为例，自动生成 MNIST 手写体数字。\n",
    "\n",
    "### 判别器：\n",
    "判别器的作用是判断一个模型生成的图像和真实图像比，有多逼真。它的基本结构就是如下图所示的卷积神经网络（Convolutional Neural Network，CNN）。对于 MNIST 数据集来说，模型输入是一个 28x28 像素的单通道图像。Sigmoid 函数的输出值在 0-1 之间，表示图像真实度的概率，其中 0 表示肯定是假的，1 表示肯定是真的。与典型的 CNN 结构相比，这里去掉了层之间的 max-pooling。这里每个 CNN 层都以 LeakyReLU 为激活函数。而且为了防止过拟合，层之间的 dropout 值均被设置在 0.4-0.7 之间，模型结构如下：\n",
    "<center><img src=\"images/Discriminator.jpg\" alt=\"FAO\" width=\"500\"></center> \n",
    "ReLU激活函数极为f(x)=alpha * x for x < 0, f(x) = x for x>=0。alpha是一个小的非零数。\n",
    "<center><img src=\"images/LeakyRelu.png\" alt=\"FAO\" width=\"200\"></center>\n",
    "\n",
    "### 生成器：\n",
    "生成器的作用是合成假的图像，其基本机构如下图所示。图中，我们使用了卷积的倒数，即[转置卷积（transposed convolution）](https://github.com/vdumoulin/conv_arithmetic)，从 100 维的噪声（满足 -1 至 1 之间的均匀分布）中生成了假图像。这里我们采用了模型前三层之间的上采样来合成更逼真的手写图像。在层与层之间，我们采用了批量归一化的方法来平稳化训练过程。以 ReLU 函数为每一层结构之后的激活函数。最后一层 Sigmoid 函数输出最后的假图像。第一层设置了 0.3-0.5 之间的 dropout 值来防止过拟合。\n",
    "<center><img src=\"images/Generator.jpg\" alt=\"FAO\" width=\"500\"></center> \n",
    "批量正则化：\n",
    "<center><img src=\"images/batch normalization.png\" alt=\"FAO\" width=\"500\"></center>\n",
    "\n",
    "### GAN应用\n",
    "[1.图像生成](http://make.girls.moe)  \n",
    "2.向量空间运算\n",
    "<center><img src=\"images/GAN1.jpg\" alt=\"FAO\" width=\"500\"></center>\n",
    "3.文本转图像\n",
    "<center><img src=\"images/GAN2.jpg\" alt=\"FAO\" width=\"500\"></center>\n",
    "4.超分辨率\n",
    "<center><img src=\"images/GAN4.jpg\" alt=\"FAO\" width=\"500\"></center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_9\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_8 (Conv2D)            (None, 14, 14, 64)        1664      \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_8 (LeakyReLU)    (None, 14, 14, 64)        0         \n",
      "_________________________________________________________________\n",
      "dropout_10 (Dropout)         (None, 14, 14, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_9 (Conv2D)            (None, 7, 7, 128)         204928    \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_9 (LeakyReLU)    (None, 7, 7, 128)         0         \n",
      "_________________________________________________________________\n",
      "dropout_11 (Dropout)         (None, 7, 7, 128)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_10 (Conv2D)           (None, 4, 4, 256)         819456    \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_10 (LeakyReLU)   (None, 4, 4, 256)         0         \n",
      "_________________________________________________________________\n",
      "dropout_12 (Dropout)         (None, 4, 4, 256)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_11 (Conv2D)           (None, 4, 4, 512)         3277312   \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_11 (LeakyReLU)   (None, 4, 4, 512)         0         \n",
      "_________________________________________________________________\n",
      "dropout_13 (Dropout)         (None, 4, 4, 512)         0         \n",
      "_________________________________________________________________\n",
      "flatten_2 (Flatten)          (None, 8192)              0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 1)                 8193      \n",
      "_________________________________________________________________\n",
      "activation_12 (Activation)   (None, 1)                 0         \n",
      "=================================================================\n",
      "Total params: 4,311,553\n",
      "Trainable params: 4,311,553\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"sequential_11\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_5 (Dense)              (None, 12544)             1266944   \n",
      "_________________________________________________________________\n",
      "batch_normalization_8 (Batch (None, 12544)             50176     \n",
      "_________________________________________________________________\n",
      "activation_13 (Activation)   (None, 12544)             0         \n",
      "_________________________________________________________________\n",
      "reshape_2 (Reshape)          (None, 7, 7, 256)         0         \n",
      "_________________________________________________________________\n",
      "dropout_14 (Dropout)         (None, 7, 7, 256)         0         \n",
      "_________________________________________________________________\n",
      "up_sampling2d_4 (UpSampling2 (None, 14, 14, 256)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_8 (Conv2DTr (None, 14, 14, 128)       819328    \n",
      "_________________________________________________________________\n",
      "batch_normalization_9 (Batch (None, 14, 14, 128)       512       \n",
      "_________________________________________________________________\n",
      "activation_14 (Activation)   (None, 14, 14, 128)       0         \n",
      "_________________________________________________________________\n",
      "up_sampling2d_5 (UpSampling2 (None, 28, 28, 128)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_9 (Conv2DTr (None, 28, 28, 64)        204864    \n",
      "_________________________________________________________________\n",
      "batch_normalization_10 (Batc (None, 28, 28, 64)        256       \n",
      "_________________________________________________________________\n",
      "activation_15 (Activation)   (None, 28, 28, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_10 (Conv2DT (None, 28, 28, 32)        51232     \n",
      "_________________________________________________________________\n",
      "batch_normalization_11 (Batc (None, 28, 28, 32)        128       \n",
      "_________________________________________________________________\n",
      "activation_16 (Activation)   (None, 28, 28, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_11 (Conv2DT (None, 28, 28, 1)         801       \n",
      "_________________________________________________________________\n",
      "activation_17 (Activation)   (None, 28, 28, 1)         0         \n",
      "=================================================================\n",
      "Total params: 2,394,241\n",
      "Trainable params: 2,368,705\n",
      "Non-trainable params: 25,536\n",
      "_________________________________________________________________\n",
      "0: [D loss: 0.693271, acc: 0.501953]  [A loss: 1.294270, acc: 0.000000]\n",
      "1: [D loss: 0.651373, acc: 0.837891]  [A loss: 1.246781, acc: 0.000000]\n",
      "2: [D loss: 0.553445, acc: 1.000000]  [A loss: 1.405086, acc: 0.000000]\n",
      "3: [D loss: 0.426323, acc: 0.847656]  [A loss: 2.246035, acc: 0.000000]\n",
      "4: [D loss: 0.355004, acc: 0.943359]  [A loss: 0.647344, acc: 0.628906]\n",
      "5: [D loss: 0.153188, acc: 1.000000]  [A loss: 0.356264, acc: 0.941406]\n",
      "6: [D loss: 0.071162, acc: 0.996094]  [A loss: 0.022109, acc: 1.000000]\n",
      "7: [D loss: 0.042331, acc: 0.990234]  [A loss: 0.004800, acc: 1.000000]\n",
      "8: [D loss: 0.038330, acc: 0.994141]  [A loss: 0.003594, acc: 1.000000]\n",
      "9: [D loss: 0.030748, acc: 1.000000]  [A loss: 0.003043, acc: 1.000000]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tensorflow.keras.datasets import mnist\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Activation, Flatten, Reshape\n",
    "from tensorflow.keras.layers import Conv2D, Conv2DTranspose, UpSampling2D\n",
    "from tensorflow.keras.layers import LeakyReLU, Dropout\n",
    "from tensorflow.keras.layers import BatchNormalization\n",
    "from tensorflow.keras.optimizers import RMSprop\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "class DCGAN(object):\n",
    "    def __init__(self, img_rows=28, img_cols=28, channel=1):\n",
    "        # 初始化图片的行列通道数\n",
    "        self.img_rows = img_rows\n",
    "        self.img_cols = img_cols\n",
    "        self.channel = channel\n",
    "        self.D = None   # discriminator 判别器\n",
    "        self.G = None   # generator 生成器\n",
    "        self.AM = None  # adversarial model 对抗模型\n",
    "        self.DM = None  # discriminator model 判别模型\n",
    "\n",
    "    # 判别模型\n",
    "    def discriminator(self):\n",
    "        if self.D:\n",
    "            return self.D\n",
    "        self.D = Sequential()\n",
    "        # 定义通道数64\n",
    "        depth = 64\n",
    "        # dropout系数\n",
    "        dropout = 0.4\n",
    "        # 输入28*28*1\n",
    "        input_shape = (self.img_rows, self.img_cols, self.channel)\n",
    "        # 输出14*14*64\n",
    "        self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape, padding='same'))\n",
    "        self.D.add(LeakyReLU(alpha=0.2))\n",
    "        self.D.add(Dropout(dropout))\n",
    "        # 输出7*7*128\n",
    "        self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))\n",
    "        self.D.add(LeakyReLU(alpha=0.2))\n",
    "        self.D.add(Dropout(dropout))\n",
    "        # 输出4*4*256\n",
    "        self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))\n",
    "        self.D.add(LeakyReLU(alpha=0.2))\n",
    "        self.D.add(Dropout(dropout))\n",
    "        # 输出4*4*512\n",
    "        self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))\n",
    "        self.D.add(LeakyReLU(alpha=0.2))\n",
    "        self.D.add(Dropout(dropout))\n",
    "\n",
    "        # 全连接层\n",
    "        self.D.add(Flatten())\n",
    "        self.D.add(Dense(1))\n",
    "        self.D.add(Activation('sigmoid'))\n",
    "        self.D.summary()\n",
    "        return self.D\n",
    "\n",
    "    # 生成模型\n",
    "    def generator(self):\n",
    "        if self.G:\n",
    "            return self.G\n",
    "        self.G = Sequential()\n",
    "        # dropout系数\n",
    "        dropout = 0.4\n",
    "        # 通道数256\n",
    "        depth = 64*4\n",
    "        # 初始平面大小设置\n",
    "        dim = 7\n",
    "        # 全连接层，100个的随机噪声数据，7*7*256个神经网络\n",
    "        self.G.add(Dense(dim*dim*depth, input_dim=100))\n",
    "        self.G.add(BatchNormalization(momentum=0.9))\n",
    "        self.G.add(Activation('relu'))\n",
    "        # 把1维的向量变成3维数据(7,7,256)\n",
    "        self.G.add(Reshape((dim, dim, depth)))\n",
    "        self.G.add(Dropout(dropout))\n",
    "\n",
    "\n",
    "        # 用法和 MaxPooling2D 基本相反，比如：UpSampling2D(size=(2, 2))\n",
    "        # 就相当于将输入图片的长宽各拉伸一倍，整个图片被放大了\n",
    "        # 上采样，采样后得到数据格式(14,14,256)\n",
    "        self.G.add(UpSampling2D()) \n",
    "        # 转置卷积，得到数据格式(14,14,128) \n",
    "        self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same')) \n",
    "        self.G.add(BatchNormalization(momentum=0.9))\n",
    "        self.G.add(Activation('relu'))\n",
    "\n",
    "        # 上采样，采样后得到数据格式(28,28,128)\n",
    "        self.G.add(UpSampling2D()) \n",
    "        # 转置卷积，得到数据格式(28,28,64) \n",
    "        self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))\n",
    "        self.G.add(BatchNormalization(momentum=0.9))\n",
    "        self.G.add(Activation('relu'))\n",
    "\n",
    "        # 转置卷积，得到数据格式(28,28,32) \n",
    "        self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same')) \n",
    "        self.G.add(BatchNormalization(momentum=0.9))\n",
    "        self.G.add(Activation('relu'))\n",
    "\n",
    "        # 转置卷积，得到数据格式(28,28,1) \n",
    "        self.G.add(Conv2DTranspose(1, 5, padding='same'))\n",
    "        self.G.add(Activation('sigmoid'))\n",
    "        self.G.summary()\n",
    "        return self.G\n",
    "\n",
    "    # 定义判别模型\n",
    "    def discriminator_model(self):\n",
    "        if self.DM:\n",
    "            return self.DM\n",
    "        # 定义优化器\n",
    "        optimizer = RMSprop(lr=0.0002, decay=6e-8)\n",
    "        # 构建模型\n",
    "        self.DM = Sequential()\n",
    "        self.DM.add(self.discriminator())\n",
    "        self.DM.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n",
    "        return self.DM\n",
    "\n",
    "    # 定义对抗模型\n",
    "    def adversarial_model(self):\n",
    "        if self.AM:\n",
    "            return self.AM\n",
    "        # 定义优化器\n",
    "        optimizer = RMSprop(lr=0.0001, decay=3e-8)\n",
    "        # 构建模型\n",
    "        self.AM = Sequential()\n",
    "        # 生成器\n",
    "        self.AM.add(self.generator())\n",
    "        # 判别器\n",
    "        self.AM.add(self.discriminator())\n",
    "        self.AM.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n",
    "        return self.AM\n",
    "\n",
    "class MNIST_DCGAN(object):\n",
    "    def __init__(self):\n",
    "        # 图片的行数\n",
    "        self.img_rows = 28\n",
    "        # 图片的列数\n",
    "        self.img_cols = 28\n",
    "        # 图片的通道数\n",
    "        self.channel = 1\n",
    "\n",
    "        # 载入数据\n",
    "        (x_train,y_train),(x_test,y_test) = mnist.load_data()\n",
    "        # (60000,28,28)\n",
    "        self.x_train = x_train/255.0\n",
    "        # 改变数据格式(samples, rows, cols, channel)(60000,28,28,1)\n",
    "        self.x_train = self.x_train.reshape(-1, self.img_rows, self.img_cols, 1).astype(np.float32)\n",
    "\n",
    "        # 实例化DCGAN类\n",
    "        self.DCGAN = DCGAN()\n",
    "        # 定义判别器模型\n",
    "        self.discriminator =  self.DCGAN.discriminator_model()\n",
    "        # 定义对抗模型\n",
    "        self.adversarial = self.DCGAN.adversarial_model()\n",
    "        # 定义生成器\n",
    "        self.generator = self.DCGAN.generator()\n",
    "\n",
    "    # 训练模型\n",
    "    def train(self, train_steps=20, batch_size=256, save_interval=0):\n",
    "        noise_input = None\n",
    "        if save_interval>0:\n",
    "            # 生成16个100维的噪声数据\n",
    "            noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])\n",
    "        for i in range(train_steps):\n",
    "        # 训练判别器，提升判别能力\n",
    "            # 随机得到一个batch的图片数据\n",
    "            images_train = self.x_train[np.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]\n",
    "            # 随机生成一个batch的噪声数据\n",
    "            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])\n",
    "            # 生成伪造的图片数据\n",
    "            images_fake = self.generator.predict(noise)\n",
    "            # 合并一个batch的真实图片和一个batch的伪造图片\n",
    "            x = np.concatenate((images_train, images_fake))\n",
    "            # 定义标签，真实数据的标签为1，伪造数据的标签为0\n",
    "            y = np.ones([2*batch_size, 1])\n",
    "            y[batch_size:, :] = 0\n",
    "            # 把数据放到判别器中进行判断\n",
    "            d_loss = self.discriminator.train_on_batch(x, y)\n",
    "        \n",
    "        # 训练对抗模型，提升生成器的造假能力\n",
    "            # 标签都定义为1\n",
    "            y = np.ones([batch_size, 1])\n",
    "            # 生成一个batch的噪声数据\n",
    "            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])\n",
    "            # 训练对抗模型\n",
    "            a_loss = self.adversarial.train_on_batch(noise, y)\n",
    "            # 打印判别器的loss和准确率，以及对抗模型的loss和准确率\n",
    "            log_mesg = \"%d: [D loss: %f, acc: %f]\" % (i, d_loss[0], d_loss[1])\n",
    "            log_mesg = \"%s  [A loss: %f, acc: %f]\" % (log_mesg, a_loss[0], a_loss[1])\n",
    "            print(log_mesg)\n",
    "            # 如果需要保存图片\n",
    "            if save_interval>0:\n",
    "                # 每save_interval次保存一次\n",
    "                if (i+1)%save_interval==0:\n",
    "                    self.plot_images(save2file=True, samples=noise_input.shape[0], noise=noise_input, step=(i+1))\n",
    "\n",
    "    # 保存图片\n",
    "    def plot_images(self, save2file=False, fake=True, samples=16, noise=None, step=0):\n",
    "        filename = 'mnist.png'\n",
    "        if fake:\n",
    "            if noise is None:\n",
    "                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])\n",
    "            else:\n",
    "                filename = \"mnist_%d.png\" % step\n",
    "            # 生成伪造的图片数据\n",
    "            images = self.generator.predict(noise)\n",
    "        else:\n",
    "            # 获得真实图片数据\n",
    "            i = np.random.randint(0, self.x_train.shape[0], samples)\n",
    "            images = self.x_train[i, :, :, :]\n",
    "\n",
    "        # 设置图片大小\n",
    "        plt.figure(figsize=(10,10))\n",
    "        # 生成16张图片\n",
    "        for i in range(images.shape[0]):\n",
    "            plt.subplot(4, 4, i+1)\n",
    "            # 获取一个张图片数据\n",
    "            image = images[i, :, :, :]\n",
    "            # 变成2维的图片\n",
    "            image = np.reshape(image, [self.img_rows, self.img_cols])\n",
    "            # 显示灰度图片\n",
    "            plt.imshow(image, cmap='gray')\n",
    "            # 不显示坐标轴\n",
    "            plt.axis('off')\n",
    "        # 保存图片\n",
    "        if save2file:\n",
    "            plt.savefig(filename)\n",
    "            plt.close('all')\n",
    "        # 不保存的话就显示图片\n",
    "        else:\n",
    "            plt.show()\n",
    "\n",
    "            \n",
    "# 实例化网络的类\n",
    "mnist_dcgan = MNIST_DCGAN()\n",
    "# 训练模型\n",
    "mnist_dcgan.train(train_steps=10, batch_size=256, save_interval=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_dcgan.plot_images(fake=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_dcgan.plot_images(fake=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_dcgan.generator.save('generator.h5')\n",
    "mnist_dcgan.discriminator.save('discriminator.h5')\n",
    "mnist_dcgan.adversarial.save('adversarial.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
